# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
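@description: Generate HumanEval completions with a Hugging Face causal LM checkpoint
and write them to a JSONL file for scoring with human-eval.

Adapted from https://github.com/QwenLM/Qwen-7B/blob/main/eval/EVALUATION.md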

usage:
$ git clone https://github.com/openai/human-eval
$ pip install -e human-eval
$ evaluate_functional_correctness sample-output-file
"""
import argparse

import jsonlines
import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


def post_decode(tokens_list, tokenizer, raw_text_len):
    """Decode generated sequences and cut each one down to a single completion.

    Args:
        tokens_list: iterable of token-id tensors (each supporting ``.cpu()``),
            where every sequence is the prompt followed by the generated tokens.
        tokenizer: object exposing ``decode(list_of_ids) -> str``.
        raw_text_len: number of leading prompt tokens to strip from each sequence.

    Returns:
        list[str]: one truncated completion per input sequence, in input order.
    """
    # Text is kept up to the first occurrence of each marker, applied in order.
    # A separate '\n\n\n' cut is unnecessary: the '\n\n' cut always lands at or
    # before it, so the result is identical.
    stop_markers = ('<|endoftext|>', '\n\n', 'def ')

    def _truncate(text):
        for marker in stop_markers:
            text = text.split(marker, 1)[0]
        return text

    return [
        _truncate(tokenizer.decode(seq.cpu().tolist()[raw_text_len:]))
        for seq in tokens_list
    ]


def generate_sample(model, tokenizer, input_txt):
    """Run ``model.generate`` on one prompt and return its decoded completion.

    No generation kwargs are passed, so decoding follows whatever generation
    config is attached to ``model``.

    Args:
        model: causal LM exposing ``device`` and ``generate``.
        tokenizer: tokenizer exposing ``encode`` and ``decode``.
        input_txt: prompt text to complete.

    Returns:
        str: the completion with the prompt removed, truncated by ``post_decode``.
    """
    prompt_ids = tokenizer.encode(input_txt)
    # Batch of one: shape (1, len(prompt_ids)), placed on the model's device.
    batch = torch.tensor([prompt_ids]).to(model.device)
    print(f"Input text: {input_txt}\n")
    generated = model.generate(batch)
    completions = post_decode(generated, tokenizer, len(prompt_ids))
    completion = completions[0]
    print(f"\nOutput text: \n{completion}\n")
    return completion


def main():
    """Generate one completion per HumanEval task and write them as JSONL.

    Reads ``{"task_id", "prompt"}`` records from ``--sample-input-file`` and
    writes ``{"task_id", "completion"}`` records to ``--sample-output-file``.
    """
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument("-c", "--checkpoint-path", type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
    # Required: without an input file the script cannot do anything useful,
    # and failing here gives a clear argparse error instead of a later TypeError.
    parser.add_argument("-f", "--sample-input-file", type=str, required=True, help="data path to HumanEval.jsonl")
    parser.add_argument("-o", "--sample-output-file", type=str, default="HumanEval_res.jsonl")
    args = parser.parse_args()

    print('Loading tokenizer ...')
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)

    print('Loading model ...')
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval()
    try:
        model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    except Exception:
        # Best-effort: fall back to the model's built-in generation config.
        # (Catching Exception rather than a bare except lets Ctrl-C still abort.)
        print("GenerationConfig not found, use default config.")
    # Apply greedy decoding on both paths, not only when the config loaded.
    model.generation_config.do_sample = False  # use greedy decoding

    # Open the input first so a bad input path does not truncate the output file.
    # jsonlines.open owns both file handles and closes them on exit.
    with jsonlines.open(args.sample_input_file) as reader:
        with jsonlines.open(args.sample_output_file, mode='w') as writer:
            for jobj in tqdm.tqdm(reader, desc='task_idx'):
                completion = generate_sample(model, tokenizer, jobj['prompt'])
                writer.write({'task_id': jobj['task_id'], "completion": completion})


if __name__ == '__main__':
    main()
